

import os
import copy
import json
import traceback
from torch import nn
import torch,gym
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import threading
from gym.wrappers import ResizeObservation, TransformObservation, GrayScaleObservation
from torch import device

import prototype_manager
from compute_distribution import wasserstein_distance_tasks
from continual_rl.experiments.tasks.make_minihack_task import make_minihack

from continual_rl.experiments.experiment import raw_experiment
from continual_rl.policies.policy_base import PolicyBase
from continual_rl.policies.impala.impala_environment_runner import ImpalaEnvironmentRunner
from continual_rl.policies.impala.impala_policy import ImpalaPolicy
from continual_rl.policies.clear.clear_monobeast import ClearMonobeast
from continual_rl.policies.DSNet.kd import KnowledgeDistiller
from continual_rl.utils.argparse_manager import ArgparseManager2
from continual_rl.utils.utils import Utils
from continual_rl.policies.impala.torchbeast.core.environment import Environment
from continual_rl.policies.DSNet.node_viz_singleton import NodeVizSingleton
from sklearn.metrics.pairwise import cosine_similarity
from continual_rl.experiments.tasks.make_atari_task import wrap_deepmind, make_atari

from ppo_states import ppo_state_custom

# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")



class DSNetEnvironmentRunner(ImpalaEnvironmentRunner):
    def __init__(self, config, policy):
        """Bind the runner to its policy and start all timestep counters at zero.

        Args:
            config: Experiment configuration, forwarded to the base runner.
            policy: Policy whose nodes this runner collects data with.
        """
        super().__init__(config, policy)
        self._policy = policy
        # One environment runner per node, keyed by the node's unique_id.
        self._cached_environment_runners = {}
        # Total timesteps seen, and timesteps since the last slow-critic refresh.
        self._total_timesteps = self._timesteps_since_update = 0

    def collect_data(self, task_spec):
        """Collect one round of data with the node the policy selects for ``task_spec``.

        The selected node's environment runner is created on first use and cached by
        node id. After collection the timestep counters and the node's usage count are
        advanced, then the configured maintenance steps run: training the newest node,
        merging surplus nodes, and refreshing the slow critic.

        Args:
            task_spec: Task description. The last character of ``task_id`` is used as
                the numeric task index; ``eval_mode`` disables node merging.

        Returns:
            Tuple ``(timesteps, all_env_data, rewards_to_report, logs_to_report)`` from
            the node's runner, with three scalar log entries appended to the logs.
        """
        policy = self._policy
        config = self._config

        node, similarity, _ = policy.get_active_node(task_spec)
        node_id = node.unique_id
        self._logger.info(
            f"Collecting data with active node {node_id} (similarity: {similarity:.4f})")

        runners = self._cached_environment_runners
        if node_id not in runners:
            runners[node_id] = node.get_environment_runner(task_spec)

        self._logger.debug(f"Starting data collection for task {int(task_spec.task_id[-1])}. "
                           f"Number of actors: {config.num_actors}, unroll length: {config.unroll_length}")

        collected = runners[node_id].collect_data(task_spec)
        timesteps, all_env_data, rewards_to_report, logs_to_report = collected

        self._logger.info(f"Collected data: {len(all_env_data)} environment entries")
        self._total_timesteps += timesteps
        self._timesteps_since_update += timesteps
        node.usage_count += timesteps

        if config.train_all:
            policy.train_all(task_spec)

        # Surplus nodes are merged only while training and only for a non-static ensemble.
        if not (task_spec.eval_mode or config.static_ensemble):
            policy.ensure_max_nodes(task_spec)

        slow_critic_due = (config.use_slow_critic
                           and self._timesteps_since_update > config.slow_critic_update_cadence)
        if slow_critic_due:
            self._timesteps_since_update = 0
            node.slow_critic.update_parameters(node.impala_trainer.actor_model)

        # Evaluation logs get a per-task suffix so they do not collide with training tags.
        suffix = ""
        if task_spec.eval_mode:
            suffix = f"_eval_{int(task_spec.task_id[-1])}"

        logs_to_report.extend([
            {'type': 'scalar', 'tag': 'num_DSNet_nodes', 'value': len(policy._nodes)},
            {'type': 'scalar', 'tag': f'active_node_id{suffix}', 'value': node_id},
            {'type': 'scalar', 'tag': f'task_similarity{suffix}', 'value': similarity},
        ])

        return timesteps, all_env_data, rewards_to_report, logs_to_report

    def cleanup(self, task_spec):
        """Clean up every node's trainer (training runs only) and drop the result generators.

        Args:
            task_spec: Task description; trainers are left untouched when
                ``task_spec.eval_mode`` is set.
        """
        is_training = not task_spec.eval_mode
        if is_training:
            for ensemble_node in self._policy._nodes:
                ensemble_node.impala_trainer.cleanup()

        del self._result_generators


class DSNetPolicy(PolicyBase):
    def __init__(self, config, observation_space, action_spaces):
        """Create an empty node ensemble and the bookkeeping used to route tasks to nodes.

        Args:
            config: Policy configuration. ``similarity_threshold`` is required;
                ``wasserstein_scale``, ``distill_alpha``, ``distill_neighbors`` and
                ``distill_update_freq`` fall back to defaults when absent.
            observation_space: Observation space handed to every node.
            action_spaces: Action spaces handed to every node.
        """
        super().__init__(config)
        self._config = config
        self._observation_space = observation_space
        self._action_spaces = action_spaces

        # Ensemble members; filled lazily by get_active_node.
        self._nodes = []

        self._task_features = {}      # {task_id: feature_matrix} for training tasks
        self._task_node_mapping = {}  # {task_id: node_id}
        self._node_tasks = {}         # {node_id: set(task_ids)}
        # Scores above this threshold cause a new node to be created for the task.
        self._similarity_threshold = self._config.similarity_threshold
        self._node_features = {}      # {node_id: feature_matrix}
        self.test_task_features = {}  # features of the task currently being evaluated

        # Scale for the Wasserstein distance; defaults to 1.0 when not configured.
        self.config.wasserstein_scale = getattr(config, 'wasserstein_scale', 1.0)

        # Knowledge-distillation settings, each with a default for a missing attribute.
        self.distiller_config = dict(
            alpha=getattr(config, 'distill_alpha', 0.3),
            k=getattr(config, 'distill_neighbors', 2),
            update_freq=getattr(config, 'distill_update_freq', 10000),
        )
        # Created later, once at least two nodes exist.
        self.distiller = None


    def _init_distiller(self):
        """Build the knowledge distiller once the ensemble holds at least two nodes.

        Does nothing with fewer than two nodes.
        """
        if len(self._nodes) < 2:
            return

        # The policy instance is passed along so the distiller can reach node features.
        self.distiller = KnowledgeDistiller(self._nodes, self.distiller_config, self)

    def _update_distiller(self):
        """Rebuild the knowledge distiller so it matches the current node list.

        With at least two nodes a fresh ``KnowledgeDistiller`` is created. With fewer,
        the distiller is dropped instead of being kept: it was built before nodes were
        merged away, and keeping it would leave it in place over that earlier ensemble.
        """
        if len(self._nodes) >= 2:
            # The policy instance is passed along so the distiller can reach node features.
            self.distiller = KnowledgeDistiller(self._nodes, self.distiller_config, self)
        else:
            self.distiller = None

    @property
    def _logger(self):
        """Logger for ``DSNet.log`` inside the configured output directory.

        ``Utils.create_logger`` is called on every access.
        """
        log_path = f"{self._config.output_dir}/DSNet.log"
        return Utils.create_logger(log_path)

    def _get_canonical_obs(self, task_spec):
        """Return the initial observation of a freshly built environment for ``task_spec``.

        Args:
            task_spec: Task description whose ``env_spec`` is used to build the environment.
        """
        raw_env = Utils.make_env(task_spec.env_spec)[0]
        return Environment(raw_env).initial()

    def _add_replay_buffer(self, source_node, target_node):
        """Feed every populated replay entry of ``source_node`` into ``target_node``.

        An entry counts as populated when its ``reservoir_val`` is greater than zero.
        Each populated entry is handed to the target trainer's
        ``on_act_unroll_complete`` through the ``new_buffers`` argument.

        Args:
            source_node: Node whose replay buffers are read.
            target_node: Node whose trainer receives the entries.
        """
        replay = source_node.impala_trainer._replay_buffers
        frames = replay['frame']
        actor_count, buffer_count = len(frames), len(frames[0])
        absorb = target_node.impala_trainer.on_act_unroll_complete

        for actor_index in range(actor_count):
            for buffer_id in range(buffer_count):
                if not replay['reservoir_val'][actor_index][buffer_id] > 0:
                    continue  # unfilled slot

                entry = {key: values[actor_index][buffer_id] for key, values in replay.items()}
                absorb(task_flags=None, actor_index=actor_index, agent_output=None,
                       env_output=None, new_buffers=entry)

    def _duplicate_node(self, source_node):
        """Create a new node initialised as a copy of ``source_node``.

        The new actor and learner both load the source *actor* weights. Optimizer
        state and replay entries are copied only when the configuration asks for it.
        The prototype loads the source slow critic when slow critics are enabled and
        the source actor otherwise.

        Args:
            source_node: Node to copy from.

        Returns:
            The newly created node. It is not appended to ``self._nodes`` here.
        """
        config = self._config
        source_trainer = source_node.impala_trainer
        clone = DSNetNode(config, self._observation_space, self._action_spaces, self)
        clone_trainer = clone.impala_trainer

        for model in (clone_trainer.actor_model, clone_trainer.learner_model):
            model.load_state_dict(source_trainer.actor_model.state_dict())

        if config.duplicate_optimizer:
            clone_trainer.optimizer.load_state_dict(source_trainer.optimizer.state_dict())

        if config.use_slow_critic:
            clone.slow_critic.load_state_dict(source_node.slow_critic.state_dict())
            prototype_source = source_node.slow_critic
        else:
            prototype_source = source_trainer.actor_model
        clone.prototype.load_state_dict(prototype_source.state_dict())

        if config.create_adds_replay:
            self._add_replay_buffer(source_node, clone)

        return clone

    def _duplicate_train_node(self, source_node):
        """Create a node that reuses the source feature layers but gets fresh heads.

        The whole source actor is copied first. Parameters whose names start with
        ``policy`` or ``baseline`` are then re-initialised: weights get Kaiming-normal
        values plus Gaussian noise scaled by ``config.noise_scale``, biases are zeroed.
        All other parameters keep the copied source values. With
        ``config.freeze_lower_layers`` the ``_conv_net._conv_net`` parameters of the
        new actor are frozen. The learner model, when present, is synchronised with
        the new actor.

        With ``config.duplicate_optimizer`` the source optimizer state of the
        ``policy``/``baseline`` parameters is copied over. ``Optimizer.state_dict()``
        keys its ``state`` by parameter index, so the lookup goes through a
        name -> index map rather than ``id()`` of a tensor (which never matched).

        NOTE(review): ``slow_critic`` and ``prototype`` of the new node are not
        synchronised here; they keep whatever ``DSNetNode.__init__`` gave them.

        Args:
            source_node: Node whose actor weights seed the new node.

        Returns:
            The newly created node. It is not appended to ``self._nodes`` here.
        """
        new_node = DSNetNode(self._config, self._observation_space, self._action_spaces, self)
        new_trainer = new_node.impala_trainer
        source_trainer = source_node.impala_trainer

        # Start from a full copy of the source actor.
        new_trainer.actor_model.load_state_dict(source_trainer.actor_model.state_dict())

        # Re-initialise only the policy and baseline heads.
        head_prefixes = ('policy', 'baseline')
        for name, param in new_trainer.actor_model.named_parameters():
            if not name.startswith(head_prefixes):
                continue
            if 'weight' in name:
                nn.init.kaiming_normal_(param.data, mode='fan_in', nonlinearity='relu')
                param.data += torch.randn_like(param.data) * self._config.noise_scale
            elif 'bias' in name:
                nn.init.zeros_(param.data)

        # Optionally freeze the lowest convolutional layers.
        if self._config.freeze_lower_layers:
            for name, param in new_trainer.actor_model.named_parameters():
                if name.startswith('_conv_net._conv_net'):
                    param.requires_grad = False

        if hasattr(new_trainer, 'learner_model'):
            new_trainer.learner_model.load_state_dict(new_trainer.actor_model.state_dict())

        if self._config.duplicate_optimizer and hasattr(source_trainer, 'optimizer'):
            source_slots = self._optimizer_slots_by_name(source_trainer)
            target_slots = self._optimizer_slots_by_name(new_trainer)
            source_state = source_trainer.optimizer.state_dict()['state']
            new_opt_state = new_trainer.optimizer.state_dict()

            for name, target_slot in target_slots.items():
                if not name.startswith(head_prefixes):
                    continue
                source_slot = source_slots.get(name)
                if source_slot in source_state:
                    # Deep copy so the two optimizers never share state tensors.
                    new_opt_state['state'][target_slot] = copy.deepcopy(source_state[source_slot])

            new_trainer.optimizer.load_state_dict(new_opt_state)

        return new_node

    @staticmethod
    def _optimizer_slots_by_name(trainer):
        """Map parameter names to their index in ``trainer.optimizer.state_dict()['state']``.

        Parameters are matched by identity against the trainer's actor and learner
        models, so this works whichever of the two the optimizer was built from.
        Optimizer parameters that belong to neither model are left out.
        """
        name_by_param_id = {}
        for attr in ('actor_model', 'learner_model'):
            model = getattr(trainer, attr, None)
            if model is None:
                continue
            for name, param in model.named_parameters():
                name_by_param_id[id(param)] = name

        slots = {}
        flat_index = 0  # state_dict() numbers parameters consecutively across groups
        for group in trainer.optimizer.param_groups:
            for param in group['params']:
                name = name_by_param_id.get(id(param))
                if name is not None:
                    slots[name] = flat_index
                flat_index += 1
        return slots




    def collect_task_features(self, task_spec):
        """Sample states from the task's Atari environment and store their features.

        A temporary environment is built for the task. ``ppo_state_custom`` gathers
        states from it (500 timesteps), the states are scaled by 1/255 and then passed
        in batches of 100 through ``AtariResNetFeatureExtractor``. The temporary
        environment is always closed afterwards, including when sampling fails.

        Training tasks store the result in ``self._task_features``. Evaluation tasks
        replace the whole content of ``self.test_task_features`` with the result.

        NOTE(review): the environment name is looked up in the hard-coded
        ``'atari_6_tasks_5_cycles'`` experiment, indexed by the last character of
        ``task_spec.task_id`` -- confirm this matches the experiment being run.

        Args:
            task_spec: Task description; ``task_id`` and ``eval_mode`` are read.

        Returns:
            The extracted features, as returned by
            ``prototype_manager.extract_features_in_batches``.
        """
        task_index = int(task_spec.task_id[-1])
        self._logger.info(f"Collecting features for task {task_index}")

        env_names = raw_experiment['atari_6_tasks_5_cycles']
        print(env_names[task_index])

        env = wrap_deepmind(make_atari(env_names[task_index], max_episode_steps=None, full_action_space=None),
            clip_rewards=False,  # If policies need to clip rewards, they should handle it themselves
            frame_stack=False,  # Handled separately
            scale=False,
        )
        # Move the channel axis to the front: (H, W, C) -> (C, H, W).
        env = TransformObservation(env, lambda obs: np.transpose(obs, (2, 0, 1)))

        try:
            env.reset()

            # eval() keeps the extractor in inference mode while features are computed.
            extract_features = prototype_manager.AtariResNetFeatureExtractor().to(
                self.config.device).eval()
            task_states = ppo_state_custom(env, total_timesteps=500)

            # Materialise the states before the environment goes away.
            states_tensor = torch.tensor(np.array(task_states), dtype=torch.float32)
        finally:
            # The environment is only needed for sampling; release it in every case.
            env.close()

        states_tensor = states_tensor / 255.0
        print(states_tensor.shape)
        states_tensor = prototype_manager.extract_features_in_batches(states_tensor, 100, extract_features,
                                                                      self.config.device)
        # The original author expected (500, 128) here -- TODO confirm.
        print("提取的特征形状:", states_tensor.shape)

        if task_spec.eval_mode:
            # Only the features of the task currently being evaluated are kept.
            self.test_task_features.clear()
            self.test_task_features[task_index] = states_tensor
        else:
            self._task_features[task_index] = states_tensor

        return states_tensor


    def compute_task_similarity(self, new_task_id, is_test_task):
        """Score a task against every node that has stored features.

        The score is the value of ``wasserstein_distance_tasks`` for the task's
        features and the node's features. The node with the smallest score is
        selected; on ties the earliest node in ``self._nodes`` wins. Nodes without an
        entry in ``self._node_features`` are skipped.

        Args:
            new_task_id: Key of the task in the feature store.
            is_test_task: When truthy the features come from
                ``self.test_task_features``, otherwise from ``self._task_features``.

        Returns:
            Tuple ``(min_similarity, similar_node, similarity_scores)`` where
            ``similarity_scores`` maps node id to score. When no node was scored this
            is ``(float('inf'), None, {})``.

        Raises:
            KeyError: If no features are stored for ``new_task_id``.
        """
        feature_store = self.test_task_features if is_test_task else self._task_features
        new_features = feature_store[new_task_id]

        best_score, best_node = float('inf'), None
        similarity_scores = {}

        for candidate in self._nodes:
            candidate_id = candidate.unique_id
            if candidate_id in self._node_features:
                score = wasserstein_distance_tasks(new_features, self._node_features[candidate_id])
                similarity_scores[candidate_id] = score

                # Strict comparison keeps the first node on ties.
                if score < best_score:
                    best_score, best_node = score, candidate

        return best_score, best_node, similarity_scores


    def train_all(self, task_flags):
        """Run one training step on the most recently added node.

        Args:
            task_flags: Task description forwarded to ``_train``.
        """
        newest_node = self._nodes[-1]
        self._train(newest_node, task_flags)


    def _train(self, node, task_flags):
        """Run one learner step for ``node`` on a batch drawn from its own trainer.

        When a distiller exists and the ensemble has at least two nodes, the trainer
        is given the distillation parameters before learning; the distillation loss
        itself is not computed here. The parameters are always cleared afterwards,
        including when ``learn`` raises, so the trainer never keeps the distiller or
        node references from a failed step.

        Args:
            node: Node to train.
            task_flags: Task description forwarded to ``learn``.

        Returns:
            The stats returned by ``learn``, or ``None`` when no batch was available.
        """
        trainer = node.impala_trainer
        batch = trainer.get_batch_for_training(None, store_for_loss=False)
        if batch is None:
            return None

        if self.distiller is not None and len(self._nodes) >= 2:
            trainer.set_distillation_params(self.distiller, node, self._task_features)

        try:
            stats = trainer.learn(
                model_flags=self._config,
                task_flags=task_flags,
                actor_model=trainer.actor_model,
                learner_model=trainer.learner_model,
                batch=batch,
                initial_agent_state=None,
                optimizer=trainer.optimizer,
                scheduler=trainer._scheduler,
                # A fresh lock per call: it is not shared with any other caller.
                lock=threading.Lock()
            )
        finally:
            trainer.set_distillation_params(None, None, None)

        self._logger.info("Completed training step with distillation")
        return stats


    def get_active_node(self, task_spec):
        """Select, or create, the node that should handle ``task_spec``.

        Behaviour:
          * No nodes yet: a first node is created and given the task's stored
            training features; the reported score is ``1.0``.
          * Evaluation: the node with the smallest score is returned as is.
          * Training, score above ``similarity_threshold``: a new node is derived
            from the closest node, registered with the task's features, and the
            distiller is rebuilt once.
          * Training, otherwise (including a score exactly equal to the threshold):
            the closest node is reused and its features are averaged with the
            task's. This matches the ``>`` / else split in ``update_available_nodes``.

        The score is a Wasserstein distance, so smaller means more similar.

        NOTE(review): if no existing node has an entry in ``self._node_features``
        (for example right after ``load``), the closest node is ``None`` and the
        calls below fail on it.

        Args:
            task_spec: Task description; the last character of ``task_id`` is used as
                the integer task index, and ``eval_mode`` selects the feature store.

        Returns:
            Tuple ``(node, similarity_score, task_index)``.
        """
        task_index = int(task_spec.task_id[-1])

        if not self._nodes:
            self._logger.info("Creating first node")
            new_node = DSNetNode(self._config, self._observation_space, self._action_spaces, self)
            self._nodes.append(new_node)
            self._node_features[new_node.unique_id] = self._task_features[task_index]
            return new_node, 1.0, task_index

        min_similarity, similar_node, _ = self.compute_task_similarity(
            task_index, is_test_task=bool(task_spec.eval_mode))

        if task_spec.eval_mode:
            self._logger.info(f"Using node {similar_node.unique_id} for task {task_index}")
            return similar_node, min_similarity, task_index

        if min_similarity > self._similarity_threshold:
            # Too far from every known node: branch off a new one.
            new_node = self._duplicate_train_node(similar_node)
            self._nodes.append(new_node)
            self._node_features[new_node.unique_id] = self._task_features[task_index]
            self._logger.info(f"Creating new node {new_node.unique_id} for task {task_index}")

            # The node set changed, so the distiller is rebuilt (once).
            self._update_distiller()
            return new_node, min_similarity, task_index

        # Close enough: reuse the node and fold the task's features into it.
        self._logger.info(f"Reusing node {similar_node.unique_id} for task {task_index}")
        self._update_node_features(similar_node, task_index)
        return similar_node, min_similarity, task_index


    def _update_node_features(self, node, new_task_id):
        """Replace a node's features with the mean of its features and a task's features.

        This is a plain two-way average of the current node features and the task
        features, not a running mean over all tasks seen.

        Args:
            node: Node whose entry in ``self._node_features`` is updated.
            new_task_id: Key of the task in ``self._task_features``.
        """
        node_id = node.unique_id
        blended = (self._node_features[node_id] + self._task_features[new_task_id]) / 2.0
        self._node_features[node_id] = blended

    def update_available_nodes(self, task_spec, total_timesteps, active_node):
        """Create a node for the task when it is too dissimilar from every existing node.

        Does nothing during evaluation or when the task has no stored training
        features. Otherwise, if the smallest score exceeds ``similarity_threshold``,
        ``active_node`` is duplicated and the task is mapped to the copy; if not, the
        task is mapped to the closest existing node.

        ``compute_task_similarity`` returns the closest *node*, so the mapping stores
        that node's ``unique_id`` directly. (It was previously treated as a task id
        and used as a key into ``_task_node_mapping``.)

        NOTE(review): unlike ``get_active_node``, a node created here is not
        registered in ``self._node_features`` and the distiller is not rebuilt --
        confirm whether that is intended.

        Args:
            task_spec: Task description; ``task_id`` and ``eval_mode`` are read.
            total_timesteps: Unused.
            active_node: Node to duplicate when a new node is needed.
        """
        # Nodes are only created during training.
        if task_spec.eval_mode:
            return

        task_index = int(task_spec.task_id[-1])
        if task_index not in self._task_features:
            return

        min_similarity, similar_node, _ = self.compute_task_similarity(task_index, is_test_task=False)

        if min_similarity > self._similarity_threshold:
            self._logger.info(
                f"Creating new node for task {task_index} (similarity: {min_similarity:.4f} > {self._similarity_threshold})")
            new_node = self._duplicate_node(active_node)
            self._nodes.append(new_node)
            self._task_node_mapping[task_index] = new_node.unique_id

            if self._config.visualize_nodes:
                NodeVizSingleton.instance().create_node(self._config.output_dir, new_node.unique_id)
                NodeVizSingleton.instance().register_created_from(self._config.output_dir, new_node.unique_id,
                                                                  active_node.unique_id)
        else:
            self._logger.info(
                f"Using existing node for task {task_index} (similarity: {min_similarity:.4f})")
            self._task_node_mapping[task_index] = similar_node.unique_id

    def _get_closest_nodes(self, mergeable_nodes):
        """Return the pair of distinct nodes whose merge metrics are closest.

        Closeness is the squared Euclidean distance between the vectors returned by
        each node's ``get_merge_metric``. All ordered pairs are sorted by distance and
        the first pair of two different nodes is returned.

        Args:
            mergeable_nodes: Candidate nodes; at least two are required.

        Returns:
            Tuple ``(node_a, node_b)`` taken from ``mergeable_nodes``.

        Raises:
            ValueError: If fewer than two nodes are given. (This used to surface as
                an ``IndexError`` because the bounds check ran after the index.)
        """
        if len(mergeable_nodes) < 2:
            raise ValueError("At least two mergeable nodes are required to pick a closest pair")

        metrics = torch.stack([node.get_merge_metric() for node in mergeable_nodes])

        # Broadcasting (N, 1, D) against (N, D) gives all pairwise differences (N, N, D).
        difference = metrics.unsqueeze(1) - metrics
        square_distances = (difference ** 2).sum(dim=-1).detach().cpu().numpy()

        side_length = square_distances.shape[0]
        order = np.argsort(square_distances, axis=None)  # flat indices, ascending distance
        rows, cols = np.divmod(order, side_length)

        # Diagonal entries compare a node with itself, so skip them.
        off_diagonal = np.flatnonzero(rows != cols)
        first = off_diagonal[0]

        return mergeable_nodes[int(rows[first])], mergeable_nodes[int(cols[first])]

    def ensure_max_nodes(self, task_flags):
        """Merge the closest nodes until at most ``config.max_nodes`` remain.

        Each round picks the closest pair among the first
        ``int(fraction_of_nodes_mergeable * max_nodes)`` nodes. The survivor takes
        over the other node's usage count and replay entries, is trained for one
        step, and the other node is removed and permanently deleted. The distiller is
        rebuilt after every merge.

        Args:
            task_flags: Task description forwarded to the training step.
        """
        config = self._config

        while len(self._nodes) > config.max_nodes:
            mergeable_count = int(config.fraction_of_nodes_mergeable * config.max_nodes)
            node_to_keep, node_to_remove = self._get_closest_nodes(self._nodes[:mergeable_count])

            # The min (non-zero) reservoir value serves as a proxy for usage count,
            # so the node with the higher value is the one that is kept.
            swap_roles = False
            if config.keep_larger_reservoir_val_in_merge:
                swap_roles = (node_to_remove.impala_trainer.get_min_reservoir_val_greater_than_zero()
                              > node_to_keep.impala_trainer.get_min_reservoir_val_greater_than_zero())
            if not swap_roles and config.usage_count_based_merge:
                swap_roles = node_to_remove.usage_count > node_to_keep.usage_count
            if swap_roles:
                node_to_keep, node_to_remove = node_to_remove, node_to_keep

            node_to_keep.usage_count += node_to_remove.usage_count
            self._add_replay_buffer(node_to_remove, node_to_keep)
            self._train(node_to_keep, task_flags)
            self._nodes.remove(node_to_remove)

            self._logger.info(f"Deleting resources for node {node_to_remove.unique_id}")
            node_to_remove.impala_trainer.permanent_delete()
            self._logger.info("Deletion complete")

            if config.visualize_nodes:
                NodeVizSingleton.instance().merge_node(config.output_dir, node_to_remove.unique_id,
                                                       node_to_keep.unique_id)

            # The node set changed, so refresh the distiller.
            self._update_distiller()

    def get_environment_runner(self, task_spec):
        """Collect the task's features, then return a runner bound to this policy.

        Args:
            task_spec: Task description passed to ``collect_task_features``.

        Returns:
            A new ``DSNetEnvironmentRunner``.
        """
        self.collect_task_features(task_spec)
        runner = DSNetEnvironmentRunner(self._config, self)
        return runner

    def compute_action(self, observation, task_id, action_space_id, last_timestep_data, eval_mode):
        """No-op in this policy; always returns ``None``."""
        return None

    def load(self, output_path_dir):
        """Restore nodes, task features and task/node mappings from ``output_path_dir``.

        Each of the three files is optional; a missing file leaves the corresponding
        state untouched.

        JSON object keys and NPZ entry names are always strings, while this class
        keys its dictionaries by integer task/node ids. Numeric keys are therefore
        converted back to ``int`` on load, so lookups made with integer ids find the
        restored entries. ``node_tasks`` values stored as lists are turned back into
        sets.

        Args:
            output_path_dir: Directory previously written by ``save``.
        """
        def _as_int_key(key):
            # Keys that are not numeric are kept unchanged rather than rejected.
            try:
                return int(key)
            except (TypeError, ValueError):
                return key

        node_metadata = os.path.join(output_path_dir, "metadata.json")
        if os.path.exists(node_metadata):
            self._nodes = []
            with open(node_metadata, "r") as metadata_file:
                all_node_data = json.load(metadata_file)

            for unique_id, node_data in all_node_data.items():
                loaded_node = DSNetNode(self._config, self._observation_space, self._action_spaces, self, int(unique_id))
                loaded_node.load(node_data["path"])
                loaded_node.usage_count = node_data["usage_count"]
                self._nodes.append(loaded_node)

        # Task features are loaded as NumPy arrays.
        feature_path = os.path.join(output_path_dir, "task_features.npz")
        if os.path.exists(feature_path):
            # The context manager closes the underlying archive once the arrays are read.
            with np.load(feature_path) as data:
                self._task_features = {_as_int_key(name): data[name] for name in data.files}

        mapping_path = os.path.join(output_path_dir, "task_mapping.json")
        if os.path.exists(mapping_path):
            with open(mapping_path, "r") as f:
                mapping_data = json.load(f)

            self._task_node_mapping = {
                _as_int_key(task_id): node_id
                for task_id, node_id in mapping_data["task_node_mapping"].items()
            }
            self._node_tasks = {
                _as_int_key(node_id): set(tasks) if isinstance(tasks, list) else tasks
                for node_id, tasks in mapping_data["node_tasks"].items()
            }

    def save(self, output_path_dir, cycle_id, task_id, task_total_steps):
        """Persist every node plus the task/node mappings under ``output_path_dir``.

        Writes:
          * ``node_save_data/node_<id>/`` -- one directory per node, filled by ``node.save``
          * ``metadata.json`` -- per-node save path and usage count
          * ``task_mapping.json`` -- ``task_node_mapping`` and ``node_tasks``

        ``node_tasks`` values that are sets are written as lists, because
        ``json.dump`` raises ``TypeError`` on a ``set``. Task features are not
        persisted.

        Args:
            output_path_dir: Target directory.
            cycle_id: Forwarded to ``node.save``.
            task_id: Forwarded to ``node.save``.
            task_total_steps: Forwarded to ``node.save``.
        """
        node_data = {}
        for node in self._nodes:
            node_path = os.path.join(output_path_dir, "node_save_data", f"node_{node.unique_id}")
            os.makedirs(node_path, exist_ok=True)
            node.save(node_path, cycle_id, task_id, task_total_steps)
            node_data[node.unique_id] = {"path": node_path, "usage_count": node.usage_count}

        node_metadata_path = os.path.join(output_path_dir, "metadata.json")
        with open(node_metadata_path, "w") as metadata_file:
            json.dump(node_data, metadata_file)

        serializable_node_tasks = {
            node_id: list(tasks) if isinstance(tasks, (set, frozenset)) else tasks
            for node_id, tasks in self._node_tasks.items()
        }

        mapping_path = os.path.join(output_path_dir, "task_mapping.json")
        with open(mapping_path, "w") as f:
            json.dump({
                "task_node_mapping": self._task_node_mapping,
                "node_tasks": serializable_node_tasks
            }, f)

    def train(self, storage_buffer):
        """No-op in this policy; always returns ``None``."""
        return None

    def cleanup(self):
        """Release the resources held by every node's trainer and by the runner, if any.

        For each trainer ``close()`` is called when it exists. Otherwise the call
        falls back to ``cleanup()``, the method ``DSNetEnvironmentRunner.cleanup``
        uses on the same trainers. A trainer offering neither is skipped with a
        warning rather than aborting the cleanup of the remaining nodes.
        """
        self._logger.info("Cleaning up policy resources...")

        for node in self._nodes:
            trainer = getattr(node, 'impala_trainer', None)
            if trainer is None:
                continue

            release = getattr(trainer, 'close', None) or getattr(trainer, 'cleanup', None)
            if release is None:
                self._logger.warning(
                    f"Trainer of node {node.unique_id} has neither close() nor cleanup(); skipping")
                continue
            release()

        # A policy-level environment runner is closed too when one has been attached.
        runner = getattr(self, '_environment_runner', None)
        if runner is not None:
            runner.close()


class DSNetMonobeast(ClearMonobeast):
    def __init__(self, model_flags, observation_space, action_spaces, policy_class):
        """Initialise the base trainer and start with distillation disabled.

        All arguments are forwarded unchanged to the parent constructor.
        """
        super().__init__(model_flags, observation_space, action_spaces, policy_class)
        # Distillation context; filled in through set_distillation_params.
        self.distiller = self.current_node = self.task_features = None

    def set_distillation_params(self, distiller, node, task_features):
        """Store the distillation context used by the next loss computation.

        Args:
            distiller: Distiller to use, or ``None`` to disable distillation.
            node: Node currently being trained.
            task_features: Task features passed on to the distiller.
        """
        self.distiller, self.current_node, self.task_features = distiller, node, task_features

    def custom_loss(self, task_flags, model, initial_agent_state, batch, vtrace_returns, distillation_loss=0):
        """Extend the parent loss with an uncertainty term and the distillation loss.

        The total is ``clear_loss_coeff * parent_loss + uncertainty_scale *
        uncertainty_loss + distillation_loss``. The uncertainty target is the
        detached absolute difference between the model baseline and
        ``vtrace_returns.vs``.

        Args:
            task_flags: Task description; ``action_space_id`` is read.
            model: Model to evaluate on ``batch``.
            initial_agent_state: Initial recurrent state passed to the model.
            batch: Training batch.
            vtrace_returns: V-trace results; ``vs`` is read.
            distillation_loss: Unused; the distillation loss is computed internally.

        Returns:
            Tuple ``(total_loss, stats)``. ``stats`` comes from the parent with
            ``uncertainty_loss``, ``distill_loss`` and ``total_loss`` added.
        """
        clear_loss, stats = super().custom_loss(
            task_flags, model, initial_agent_state, batch, vtrace_returns
        )

        flags = self._model_flags
        model_outputs, _ = model(batch, task_flags.action_space_id, initial_agent_state)

        # The uncertainty head is regressed onto the (detached) baseline error.
        baseline_error = (model_outputs['baseline'] - vtrace_returns.vs).abs()
        uncertainty_loss = (model_outputs['uncertainty'] - baseline_error.detach()).pow(2).mean()

        total_loss = flags.clear_loss_coeff * clear_loss + flags.uncertainty_scale * uncertainty_loss

        distill_value = self._compute_distillation_loss(batch)
        self.logger.info(f"Adding distill loss to total loss: {distill_value.item()}")

        def log_distill_grad(grad):
            # Gradient hook: log the norm and pass the gradient through unchanged.
            self.logger.info(f"Distill grad norm: {grad.norm().item():.6f}")
            return grad

        if distill_value.requires_grad:
            distill_value.register_hook(log_distill_grad)

        total_loss = total_loss + distill_value

        stats.update({
            "uncertainty_loss": uncertainty_loss.item(),
            "distill_loss": distill_value.item(),
            "total_loss": total_loss.item(),
        })

        return total_loss, stats

    def _compute_distillation_loss(self, batch):
        """Return the distillation loss for ``batch``, or a zero tensor without a distiller.

        Args:
            batch: Training batch passed on to the distiller.

        Returns:
            The result of ``distiller.distill`` for the current node, or
            ``torch.tensor(0.0, requires_grad=True)`` when no distiller is set.
        """
        if self.distiller is None:
            return torch.tensor(0.0, requires_grad=True)

        loss = self.distiller.distill(self.current_node, batch, self.task_features)
        self.logger.info(f"Distillation loss computed: {loss.item()}")
        return loss




class DSNetNode(ImpalaPolicy):
    """
    Each node uses its own CLEAR-based Monobeast so that each has its own separate replay buffer
    """
    UNIQUE_ID_COUNTER = 0

    def __init__(self, config, observation_space, action_spaces, ensemble, unique_id=None):
        """Create a node with its own trainer, a prototype and optionally a slow critic.

        Args:
            config: Ensemble configuration. A deep copy with a node-specific
                ``policy_unique_id`` is handed to the parent constructor.
            observation_space: Observation space for the node's policy.
            action_spaces: Action spaces for the node's policy.
            ensemble: Owning policy; its current node count becomes ``self.id``.
            unique_id: Explicit id (e.g. when loading); a fresh one is drawn when ``None``.
        """
        self.unique_id = self._get_unique_id(unique_id)

        node_config = copy.deepcopy(config)
        node_config.policy_unique_id = f"{config.policy_unique_id}node_{self.unique_id}"
        super().__init__(node_config, observation_space, action_spaces, impala_class=DSNetMonobeast)

        self._ensemble = ensemble
        self.usage_count = 0

        # Position-based identifier: the number of nodes in the ensemble at creation time.
        self.id = len(ensemble._nodes)

        if not config.use_slow_critic:
            self.prototype = copy.deepcopy(self.impala_trainer.actor_model)
            return

        if config.slow_critic_ema_new_weight > 0:
            def avg_fn(averaged_model_parameter, model_parameter, num_averaged):
                # Exponential moving average with the configured weight on the new value.
                return (1 - config.slow_critic_ema_new_weight) * averaged_model_parameter \
                    + config.slow_critic_ema_new_weight * model_parameter
        else:
            # None means an equally weighted average.
            avg_fn = None

        # The slow critic averages the actor model, which is on cpu.
        self.slow_critic = optim.swa_utils.AveragedModel(self.impala_trainer.actor_model, avg_fn=avg_fn)
        self.prototype = copy.deepcopy(self.slow_critic)



    @classmethod
    def _get_unique_id(cls, unique_id=None):
        """Return a node id, drawing a fresh one from the class counter when none is given.

        Args:
            unique_id: Explicit id to use, or ``None`` to allocate the next free id.

        Returns:
            The explicit id, or the newly allocated one. Either way the class counter
            ends up greater than the returned id.
        """
        if unique_id is None:
            allocated = cls.UNIQUE_ID_COUNTER
            cls.UNIQUE_ID_COUNTER = allocated + 1
            return allocated

        # Explicit ids may arrive in any order, so the counter only ever moves forward.
        cls.UNIQUE_ID_COUNTER = max(cls.UNIQUE_ID_COUNTER, unique_id + 1)
        return unique_id

    def slow_critic_forward(self, obs, action_space_id):
        """Run the slow critic on ``obs`` for the given action space and return its output."""
        slow_critic_output = self.slow_critic(obs, action_space_id)
        return slow_critic_output



    def policy_forward(self, inputs, action_space_id):
        """Run the actor model on CPU and return its policy logits.

        The actor model is moved to the CPU. The ``frame``, ``last_action``,
        ``reward`` and ``done`` entries are moved to the CPU in a shallow copy of
        ``inputs``, so the caller's dictionary is no longer modified and its tensors
        stay on their original device.

        Failures are handled best-effort, as before: the traceback is printed instead
        of raising. If preparing the inputs fails, the forward pass is still
        attempted with whatever was prepared.

        Args:
            inputs: Mapping with at least the four entries named above.
            action_space_id: Action space identifier passed to the model.

        Returns:
            The ``policy_logits`` tensor, or ``None`` if the forward pass failed.
        """
        cpu = torch.device('cpu')
        model = self.impala_trainer.actor_model.to(cpu)

        local_inputs = inputs
        try:
            local_inputs = dict(inputs)
            for key in ('frame', 'last_action', 'reward', 'done'):
                local_inputs[key] = local_inputs[key].to(cpu)
        except Exception:
            traceback.print_exc()

        try:
            output, _ = model(local_inputs, action_space_id)
            return output["policy_logits"]
        except Exception:
            traceback.print_exc()
            return None



    def prototype_forward(self, obs, action_space_id):
        """Run the prototype model on ``obs`` for the given action space and return its output."""
        prototype_output = self.prototype(obs, action_space_id)
        return prototype_output

    def get_merge_metric(self):
        """Summarise this node's replay data as a vector used to compare nodes for merging.

        The data comes from a training batch when ``merge_by_batch`` is set and a
        batch is available, otherwise from the raw replay buffers. With
        ``merge_by_frame`` the metric is the averaged frame, flattened; otherwise it
        is the averaged policy logits.

        Returns:
            The metric as a CPU tensor.
        """
        config = self._config
        trainer = self.impala_trainer

        data = None
        if config.merge_by_batch:
            data = trainer.get_batch_for_training(batch=None, store_for_loss=False,
                                                  reuse_actor_indices=True,
                                                  replay_entry_scale=config.merge_batch_scale)
        if data is None:
            data = trainer._replay_buffers  # Will possibly include unfilled entries

        if config.merge_by_frame:
            frames = data['frame']
            # A batch is already stacked; raw buffers are stacked and averaged first.
            if not isinstance(frames, torch.Tensor):
                frames = torch.stack(frames).float().mean(dim=0)
            metric = frames.float()
            for _ in range(3):
                metric = metric.mean(dim=0)
            metric = metric.view(-1)
        else:
            logits = data['policy_logits']
            if not isinstance(logits, torch.Tensor):
                logits = torch.stack(logits).mean(dim=0)
            metric = logits.mean(dim=0).mean(dim=0)

        return metric.cpu()





